# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from https://github.com/facebookresearch/PoseDiffusion/blob/main/pose_diffusion/models/gaussian_diffuser.py

import math
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F
import numpy as np
from einops import rearrange, reduce
from utils.utils import calc_vos_simple
# constants

# The two interchangeable parameterizations of a diffusion model's output at a
# timestep: the predicted noise epsilon and the implied clean sample x_0.
ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start"])

# helpers functions


def exists(x):
    """Return True when ``x`` is not None, False otherwise."""
    return x is not None


def default(val, d):
    """Return ``val`` unless it is None; otherwise fall back to ``d``.

    A callable fallback is invoked lazily, so expensive defaults are only
    computed when actually needed.
    """
    if val is None:
        return d() if callable(d) else d
    return val


def extract(a, t, x_shape):
    """Gather per-timestep scalars ``a[t]`` and reshape for broadcasting.

    The result has shape (batch, 1, 1, ...) with as many trailing singleton
    dims as needed to broadcast against a tensor of shape ``x_shape``.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)


def linear_beta_schedule(timesteps):
    """Linear beta schedule scaled so the endpoints match the canonical
    1000-step DDPM schedule (1e-4 .. 0.02) regardless of ``timesteps``."""
    scale = 1000 / timesteps
    return torch.linspace(
        scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float64
    )


def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine beta schedule.

    As proposed in https://openreview.net/forum?id=-NEXDKk8gZ: betas are
    derived from a squared-cosine alpha-bar curve (offset by ``s``) and
    clipped to [0, 0.999] to avoid a degenerate final step.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    bar_alpha = torch.cos((grid / timesteps + s) / (1 + s) * math.pi * 0.5) ** 2
    bar_alpha = bar_alpha / bar_alpha[0]
    betas = 1 - bar_alpha[1:] / bar_alpha[:-1]
    return torch.clip(betas, 0, 0.999)


class GaussianDiffusion(nn.Module):
    """DDPM-style Gaussian diffusion over pose vectors.

    Wraps the forward (noising) process q(x_t | x_0), the reverse posterior
    q(x_{t-1} | x_t, x_0), DDPM ancestral sampling and DDIM sampling.

    NOTE(review): ``self.model`` is never defined in this class — the
    denoising network (called as ``model(x, t, z)``) is presumably attached
    externally before training/sampling; confirm against the caller.
    """

    def __init__(
        self,
        timesteps=100,
        sampling_timesteps=None,
        beta_1=0.0001,
        beta_T=0.1,
        loss_type="AtLocPlus",
        objective="pred_noise",
        beta_schedule="custom",
        p2_loss_weight_gamma=0.0,
        p2_loss_weight_k=1,
    ):
        """
        Args:
            timesteps: number of diffusion steps T used at training time.
            sampling_timesteps: steps used at sampling time; defaults to T.
            beta_1, beta_T: endpoints of the "custom" linear beta schedule.
            loss_type: 'l1' -> nn.L1Loss, 'AtLoc' -> AtLocCriterion,
                anything else -> AtLocPlusCriterion.
            objective: 'pred_noise' (model predicts epsilon) or 'pred_x0'
                (model predicts the clean sample directly).
            beta_schedule: 'linear', 'cosine', or 'custom'.
            p2_loss_weight_gamma, p2_loss_weight_k: P2 reweighting
                parameters; gamma=0 makes the weights all ones (disabled).
        """
        super().__init__()

        assert objective in {
            "pred_noise",
            "pred_x0",
        }, "objective must be either pred_noise (predict noise) \
            or pred_x0 (predict image start)"

        self.timesteps = timesteps
        self.sampling_timesteps = sampling_timesteps
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.loss_type = loss_type
        self.objective = objective
        self.beta_schedule = beta_schedule
        self.p2_loss_weight_gamma = p2_loss_weight_gamma
        self.p2_loss_weight_k = p2_loss_weight_k

        self.init_diff_hyper(
            self.timesteps,
            self.sampling_timesteps,
            self.beta_1,
            self.beta_T,
            self.loss_type,
            self.objective,
            self.beta_schedule,
            self.p2_loss_weight_gamma,
            self.p2_loss_weight_k,
        )

    def init_diff_hyper(
        self,
        timesteps,
        sampling_timesteps,
        beta_1,
        beta_T,
        loss_type,
        objective,
        beta_schedule,
        p2_loss_weight_gamma,
        p2_loss_weight_k,
    ):
        """Precompute all schedule-derived constants and register them as
        (float32) buffers so they follow ``.to(device)`` / ``state_dict``."""
        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "custom":
            betas = torch.linspace(
                beta_1, beta_T, timesteps, dtype=torch.float64
            )
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha-bar shifted right by one; alpha-bar_{-1} := 1
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # default num sampling timesteps to number of timesteps at training
        self.sampling_timesteps = (
            sampling_timesteps if sampling_timesteps is not None else timesteps
        )
        assert self.sampling_timesteps <= timesteps

        # helper to register a buffer downcast from float64 to float32
        register_buffer = lambda name, val: self.register_buffer(
            name, val.to(torch.float32)
        )

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_buffer(
            "log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)
        )
        register_buffer(
            "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)
        )
        register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer("posterior_variance", posterior_variance)

        # log calculation clipped because the posterior variance is 0
        # at the beginning of the diffusion chain
        register_buffer(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev)
            * torch.sqrt(alphas)
            / (1.0 - alphas_cumprod),
        )

        # P2 reweighting; with gamma=0 this is all ones (currently unused
        # by the loss below, kept for compatibility)
        register_buffer(
            "p2_loss_weight",
            (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
            ** -p2_loss_weight_gamma,
        )

        # select the training loss function
        if self.loss_type == 'l1':
            self.loss_fn = nn.L1Loss()
        elif self.loss_type == 'AtLoc':
            self.loss_fn = AtLocCriterion()
        else:
            self.loss_fn = AtLocPlusCriterion()

    # helper functions
    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and noise:
        x_0 = x_t / sqrt(abar_t) - sqrt(1/abar_t - 1) * noise."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Inverse of predict_start_from_noise: recover epsilon from x_t, x_0."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def q_posterior(self, x_start, x_t, t):
        """Mean/variance of the true reverse posterior q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return (
            posterior_mean,
            posterior_variance,
            posterior_log_variance_clipped,
        )

    def q_sample(self, x_start, t, noise=None):
        """Forward process: sample x_t ~ q(x_t | x_0) in closed form."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def model_predictions(self, x, t, z, x_self_cond=None):
        """Run the network and express its output as (pred_noise, pred_x_start)
        regardless of the configured objective."""
        model_output = self.model(x, t, z)

        if self.objective == "pred_noise":
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, model_output)
        elif self.objective == "pred_x0":
            pred_noise = self.predict_noise_from_start(x, t, model_output)
            x_start = model_output
        else:
            # objective is validated in __init__; guard against later mutation
            raise ValueError(f"unknown objective {self.objective}")

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(
        self,
        x: torch.Tensor,  # B x N_x x dim
        t: int,
        z: torch.Tensor,
        x_self_cond=None,
        clip_denoised=False,
    ):
        """Mean/variance of p(x_{t-1} | x_t) using the model's x_0 estimate."""
        preds = self.model_predictions(x, t, z)
        x_start = preds.pred_x_start

        if clip_denoised:
            raise NotImplementedError(
                "We don't clip the output because \
                    pose does not have a clear bound."
            )

        (
            model_mean,
            posterior_variance,
            posterior_log_variance,
        ) = self.q_posterior(x_start=x_start, x_t=x, t=t)

        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(
        self,
        x: torch.Tensor,  # B x N_x x dim
        t: int,
        z: torch.Tensor,
        x_self_cond=None,
        clip_denoised=False,
        cond_fn=None,
        cond_start_step=0,
    ):
        """One ancestral DDPM reverse step: sample x_{t-1} given x_t.

        When ``cond_fn`` is given and t < cond_start_step, the mean is adjusted
        by the guidance function and no noise is added.
        """
        batched_times = torch.full(
            (x.shape[0],), t, device=x.device, dtype=torch.long
        )
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x=x,
            t=batched_times,
            z=z,
            x_self_cond=x_self_cond,
            clip_denoised=clip_denoised,
        )

        if cond_fn is not None and t < cond_start_step:
            model_mean = cond_fn(model_mean, t)
            noise = 0.0
        else:
            noise = torch.randn_like(x) if t > 0 else 0.0  # no noise if t == 0

        pred = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred, x_start

    @torch.no_grad()
    def p_sample_loop(
        self,
        shape,
        z: torch.Tensor,
        cond_fn=None,
        cond_start_step=0,
    ):
        """Full DDPM sampling loop starting from pure Gaussian noise.

        Returns:
            (final sample, trajectory): the trajectory stacks the initial
            noise plus every intermediate state (num_timesteps + 1 entries).
        """
        device = self.betas.device
        pose = torch.randn(shape, device=device)

        pose_process = [pose.unsqueeze(0)]
        for t in reversed(range(0, self.num_timesteps)):
            pose, _ = self.p_sample(
                x=pose,
                t=t,
                z=z,
                cond_fn=cond_fn,
                cond_start_step=cond_start_step,
            )
            pose_process.append(pose.unsqueeze(0))

        return pose, torch.cat(pose_process)

    # ddpm generation
    @torch.no_grad()
    def sample(self, shape, z, cond_fn=None, cond_start_step=0):
        """Generate samples via ancestral DDPM sampling."""
        # TODO: add more variants
        sample_fn = self.p_sample_loop
        return sample_fn(
            shape, z=z, cond_fn=cond_fn, cond_start_step=cond_start_step
        )

    # ddim generation
    @torch.no_grad()
    def ddim_sample(self, shape, z, sampling_timesteps=20):
        """Deterministic DDIM sampling (eta = 0) over a strided time grid.

        Returns the final sample twice (matching p_sample_loop's 2-tuple
        shape, without the full trajectory).
        """
        batch_size = shape[0]
        device = self.betas.device
        img = torch.randn(shape, device=device)

        # [-1, 0, 1, 2, ..., T-1]; -1 marks the fully-denoised terminal state
        times = torch.linspace(-1, self.num_timesteps - 1,
                               steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        for time, time_next in time_pairs:
            time_cond = torch.full((batch_size,), time, device=device, dtype=torch.long)
            pred = self.model_predictions(img, time_cond, z)
            pred_noise = pred.pred_noise
            # BUGFIX: take x_start from the CURRENT prediction. Previously this
            # line was commented out, so the `time_next < 0` branch read an
            # undefined (first iteration) or stale (last iteration) x_start.
            x_start = pred.pred_x_start

            if time_next < 0:
                img = x_start
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            eta = 0  # eta = 0 -> fully deterministic DDIM update
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)
            img = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise

        return img, img

    def p_losses(
        self,
        x_start,
        t,
        z=None,
        noise=None,
    ):
        """Single training step: noise x_start to x_t, run the model, and
        compute the loss against the objective-appropriate target.

        Returns a dict with the loss, the sampled noise, the model's x_0
        estimate, the noised input x_t, and the timesteps t.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        model_out = self.model(x, t, z)

        if self.objective == "pred_noise":
            target = noise
            x_0_pred = self.predict_start_from_noise(x, t, model_out)
        elif self.objective == "pred_x0":
            target = x_start
            x_0_pred = model_out
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = self.loss_fn(model_out, target)

        return {
            "diffloss": loss,
            "noise": noise,
            "x_0_pred": x_0_pred,
            "x_t": x,
            "t": t,
        }

    def forward(self, pose, z=None, *args, **kwargs):
        """Training forward pass: draw a uniform random timestep per batch
        element and return the p_losses dict."""
        b = len(pose)
        t = torch.randint(
            0, self.num_timesteps, (b,), device=pose.device
        ).long()
        return self.p_losses(pose, t, *args, z=z, **kwargs)


class AtLocCriterion(nn.Module):
    """Absolute-pose criterion: applies ``loss_fn`` to predictions and
    targets with the leading batch/sequence dimensions merged."""

    def __init__(self, loss_fn=nn.L1Loss()):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, pred, targ):
        # merge the first two dims so loss_fn sees one pose per row
        trailing = pred.size()[2:]
        flat_pred = pred.view(-1, *trailing)
        flat_targ = targ.view(-1, *trailing)
        return self.loss_fn(flat_pred, flat_targ)


class AtLocPlusCriterion(nn.Module):
    """AtLoc+ criterion: absolute-pose loss plus a visual-odometry (VO) loss
    on the relative poses between consecutive frames."""

    def __init__(self, loss_fn=nn.L1Loss()):
        super().__init__()
        self.loss_fn = loss_fn

    def _flat_loss(self, a, b):
        # apply loss_fn with the leading batch/sequence dims merged
        trailing = a.size()[2:]
        return self.loss_fn(a.view(-1, *trailing), b.view(-1, *trailing))

    def forward(self, pred, targ):
        # absolute pose term
        abs_term = self._flat_loss(pred, targ)
        # relative (VO) term on consecutive-frame pose deltas
        vo_term = self._flat_loss(calc_vos_simple(pred), calc_vos_simple(targ))
        return abs_term + vo_term